# -*- coding: utf-8 -*-
"""The user agent class."""
from typing import Type, Any

from pydantic import BaseModel

from ._agent_base import AgentBase
from ._user_input import UserInputBase, TerminalUserInput
from ..message import Msg


class UserAgent(AgentBase):
    """The class for user interaction, allowing developers to handle the user
    input from different sources, such as web UI, cli, and other interfaces.

    The input source is a `UserInputBase` object. By default every instance
    uses the class-level `TerminalUserInput` object. It can be replaced for a
    single instance with `override_instance_input_method`, or for the whole
    class with `override_class_input_method`.
    """

    _input_method: UserInputBase = TerminalUserInput()
    """The user input method, can be overridden by calling the
    `override_instance_input_method` or `override_class_input_method`
    function."""

    def __init__(
        self,
        name: str,
    ) -> None:
        """Initialize the user agent with a name.

        Args:
            name (`str`):
                The agent name. It is passed to the input method as
                `agent_name` and used as the sender name of reply messages.
        """
        super().__init__()

        self.name = name

    async def reply(
        self,
        msg: Msg | list[Msg] | None = None,
        structured_model: Type[BaseModel] | None = None,
    ) -> Msg:
        """Prompt the user through the input method and wrap the answer in a
        reply message.

        Args:
            msg (`Msg | list[Msg] | None`, defaults to `None`):
                Not used by this implementation. The user is always prompted
                through the input method, whatever value is passed here.
            structured_model (`Type[BaseModel] | None`, defaults to `None`):
                A child class of `pydantic.BaseModel` that defines the
                structured output format. It is forwarded to the input
                method, and the structured data it returns is stored in the
                `metadata` of the reply message.

        Returns:
            `Msg`:
                A message with role `"user"` sent by this agent. If the user
                input is a single text block, the content is that block's
                text string. Otherwise it is the list of blocks returned by
                the input method.
        """
        # Get the input from the specified input method. `self.id` is
        # presumably provided by `AgentBase` (not defined in this class).
        input_data = await self._input_method(
            agent_id=self.id,
            agent_name=self.name,
            structured_model=structured_model,
        )

        blocks_input = input_data.blocks_input
        if (
            blocks_input
            and len(blocks_input) == 1
            and blocks_input[0].get("type") == "text"
        ):
            # Turn blocks_input into a string if only one text block exists
            blocks_input = blocks_input[0].get("text")

        # Use a distinct name so the (unused) `msg` parameter is not shadowed.
        reply_msg = Msg(
            self.name,
            content=blocks_input,
            role="user",
            metadata=input_data.structured_input,
        )

        await self.print(reply_msg)

        return reply_msg

    @staticmethod
    def _validate_input_method(input_method: Any) -> None:
        """Raise if `input_method` is not a `UserInputBase` instance.

        Shared by `override_instance_input_method` and
        `override_class_input_method` so both report the same error.

        Raises:
            `ValueError`:
                If `input_method` is not an instance of `UserInputBase`
                (or one of its subclasses).
        """
        if not isinstance(input_method, UserInputBase):
            raise ValueError(
                f"The input method should be an instance of the child class "
                f"of `UserInputBase`, but got {type(input_method)} instead.",
            )

    def override_instance_input_method(
        self,
        input_method: UserInputBase,
    ) -> None:
        """Override the input method of the current UserAgent instance.

        The new method is stored on the instance, so it takes precedence over
        the class-level input method for this instance only.

        Args:
            input_method (`UserInputBase`):
                The callable input method, which should be an object of a
                class that inherits from `UserInputBase`.

        Raises:
            `ValueError`:
                If `input_method` is not a `UserInputBase` instance.
        """
        self._validate_input_method(input_method)
        self._input_method = input_method

    @classmethod
    def override_class_input_method(
        cls,
        input_method: UserInputBase,
    ) -> None:
        """Override the input method of the current UserAgent class.

        The new method is stored on `cls`. It affects every instance of that
        class which has not set its own method through
        `override_instance_input_method`.

        Args:
            input_method (`UserInputBase`):
                The callable input method, which should be an object of a
                class that inherits from `UserInputBase`.

        Raises:
            `ValueError`:
                If `input_method` is not a `UserInputBase` instance.
        """
        cls._validate_input_method(input_method)
        cls._input_method = input_method

    async def handle_interrupt(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else.

        Raises:
            `NotImplementedError`:
                Always; interrupt handling is not supported by this class.
        """
        raise NotImplementedError(
            f"The handle_interrupt function is not implemented in "
            f"{self.__class__.__name__}",
        )

    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Observe the message(s) from the other agents or the environment.

        This is a no-op: the user agent does not store or react to observed
        messages.

        Args:
            msg (`Msg | list[Msg] | None`):
                The observed message(s); ignored.
        """
